{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.00130444, -0.03117937, -0.0422061 , -0.01014696], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "# Environment: CartPole-v1 exposed through the old 4-tuple step API\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    # reset() returns only the observation; step() returns\n",
    "    # (state, reward, done, info). Episodes end after max_steps steps.\n",
    "    # When an episode ends by truncation (the step cap or the inner env's\n",
    "    # own truncation) rather than by termination, info['TimeLimit.truncated']\n",
    "    # is True, so callers can tell a time-out apart from a real failure.\n",
    "\n",
    "    def __init__(self, max_steps=200):\n",
    "        env = gym.make('CartPole-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.max_steps = max_steps\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self, **kwargs):\n",
    "        # kwargs such as seed=... are forwarded to the wrapped env\n",
    "        state, _ = self.env.reset(**kwargs)\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        self.step_n += 1\n",
    "        truncated = truncated or self.step_n >= self.max_steps\n",
    "        done = terminated or truncated\n",
    "        # a time-limit cut-off is not a terminal state of the task\n",
    "        info['TimeLimit.truncated'] = bool(truncated and not terminated)\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAF7CAYAAAD4/3BBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAApWElEQVR4nO3df3SU5Z338c9MfgyEMBMDJJNIgigUiBBsQcOsrWuXlADRlTXuUcsKdnnkSBOPGmsxXatid42Le9YfXYRznu2Ke46U1h7RSgWLIKHWCJiS8kuzwkMbXDIJlSczIZpJMnM9f/gw21F+ZELIXEPer3PuczL3dc19f+/r5JAP1/3LYYwxAgAAsIgz0QUAAAB8EQEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFgnoQFl1apVuuyyyzRs2DCVlJRo165diSwHAABYImEB5Wc/+5mqq6v16KOP6ne/+52mT5+usrIytbW1JaokAABgCUeiXhZYUlKiq6++Wv/2b/8mSYpEIiooKNA999yjhx56KBElAQAAS6QmYqfd3d1qaGhQTU1NdJ3T6VRpaanq6+u/1D8UCikUCkU/RyIRnThxQqNGjZLD4RiUmgEAwPkxxqijo0P5+flyOs9+EichAeVPf/qTwuGwcnNzY9bn5ubqww8//FL/2tparVixYrDKAwAAF9DRo0c1duzYs/ZJSECJV01Njaqrq6OfA4GACgsLdfToUbnd7gRWBgAA+ioYDKqgoEAjR448Z9+EBJTRo0crJSVFra2tMetbW1vl9Xq/1N/lcsnlcn1pvdvtJqAAAJBk+nJ5RkLu4klPT9eMGTO0devW6LpIJKKtW7fK5/MloiQAAGCRhJ3iqa6u1uLFizVz5kxdc801euaZZ9TZ2anvfOc7iSoJAABYImEB5dZbb9Xx48f1yCOPyO/366qrrtLmzZu/dOEsAAAYehL2HJTzEQwG5fF4FAgEuAYFAIAkEc/fb97FAwAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgnQEPKI899pgcDkfMMnny5Gh7V1eXKisrNWrUKGVmZqqiokKtra0DXQYAAEhiF2QG5corr1RLS0t0eeedd6Jt999/v15//XW9/PLLqqur07Fjx3TzzTdfiDIAAECSSr0gG01Nldfr/dL6QCCgn/zkJ1q3bp3+6q/+SpL0wgsvaMqUKXrvvfc0a9asC1EOAABIMhdkBuWjjz5Sfn6+Lr/8ci1cuFDNzc2SpIaGBvX09Ki0tDTad/LkySosLFR9ff0ZtxcKhRQMBmMWAABw8RrwgFJSUqK1a9dq8+bNWr16tY4cOaJvfOMb6ujokN/vV3p6urKysmK+k5ubK7/ff8Zt1tbWyuPxRJeCgoKBLhsAAFhkwE/xzJs3L/pzcXGxSkpKNG7cOP385z/X8OHD+7XNmpoaVVdXRz8Hg0FCCgAAF7ELfptxVlaWvvKVr+jQoUPyer3q7u5We3t7TJ/W1tbTXrNyisvlktvtjlkAAMDF64IHlJMnT+rw4cPKy8vTjBkzlJaWpq1bt0bbm5qa1NzcLJ/Pd6FLAQAASWLAT/F873vf04033qhx48bp2LFjevTRR5WSkqLbb79dHo9HS5YsUXV1tbKzs+V2u3XPPffI5/NxBw8AAIga8IDy8ccf6/bbb9cnn3yiMWPG6Otf/7ree+89jRkzRpL09NNPy+l0qqKiQqFQSGVlZXr++ecHugwAAJDEHMYYk+gi4hUMBuXx
eBQIBLgeBQCAJBHP32/exQMAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsE7cAWXHjh268cYblZ+fL4fDoVdffTWm3RijRx55RHl5eRo+fLhKS0v10UcfxfQ5ceKEFi5cKLfbraysLC1ZskQnT548rwMBAAAXj7gDSmdnp6ZPn65Vq1adtn3lypV67rnntGbNGu3cuVMjRoxQWVmZurq6on0WLlyoAwcOaMuWLdq4caN27NihpUuX9v8oAADARcVhjDH9/rLDoQ0bNmjBggWSPp89yc/P1wMPPKDvfe97kqRAIKDc3FytXbtWt912mz744AMVFRVp9+7dmjlzpiRp8+bNmj9/vj7++GPl5+efc7/BYFAej0eBQEBut7u/5QMAgEEUz9/vAb0G5ciRI/L7/SotLY2u83g8KikpUX19vSSpvr5eWVlZ0XAiSaWlpXI6ndq5c+dptxsKhRQMBmMWAABw8RrQgOL3+yVJubm5Metzc3OjbX6/Xzk5OTHtqampys7Ojvb5otraWnk8nuhSUFAwkGUDAADLJMVdPDU1NQoEAtHl6NGjiS4JAABcQAMaULxerySptbU1Zn1ra2u0zev1qq2tLaa9t7dXJ06ciPb5IpfLJbfbHbMAAICL14AGlPHjx8vr9Wrr1q3RdcFgUDt37pTP55Mk+Xw+tbe3q6GhIdpn27ZtikQiKikpGchyAABAkkqN9wsnT57UoUOHop+PHDmixsZGZWdnq7CwUPfdd5/+8R//URMnTtT48eP1wx/+UPn5+dE7faZMmaK5c+fqrrvu0po1a9TT06OqqirddtttfbqDBwAAXPziDijvv/++vvnNb0Y/V1dXS5IWL16stWvX6vvf/746Ozu1dOlStbe36+tf/7o2b96sYcOGRb/z0ksvqaqqSrNnz5bT6VRFRYWee+65ATgcAABwMTiv56AkCs9BAQAg+STsOSgAAAADgYACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6cQeUHTt26MYbb1R+fr4cDodeffXVmPY777xTDocjZpk7d25MnxMnTmjhwoVyu93KysrSkiVLdPLkyfM6EAAAcPGIO6B0dnZq+vTpWrVq1Rn7zJ07Vy0tLdHlpz/9aUz7woULdeDAAW3ZskUbN27Ujh07tHTp0virBwAAF6XUeL8wb948zZs376x9XC6XvF7vads++OADbd68Wbt379bMmTMlST/+8Y81f/58/cu//Ivy8/PjLQkAAFxkLsg1KNu3b1dOTo4mTZqkZcuW6ZNPPom21dfXKysrKxpOJKm0tFROp1M7d+487fZCoZCCwWDMAgAALl4DHlDmzp2r//zP/9TWrVv1z//8z6qrq9O8efMUDoclSX6/Xzk5OTHfSU1NVXZ2tvx+/2m3WVtbK4/HE10KCgoGumwAAGCRuE/xnMttt90W/XnatGkqLi7WFVdcoe3bt2v27Nn92mZNTY2qq6ujn4PBICEFAICL2AW/zfjyyy/X6NGjdejQIUmS1+tVW1tbTJ/e3l6dOHHijNetuFwuud3umAUAAFy8LnhA+fjjj/XJJ58oLy9PkuTz+dTe3q6GhoZon23btikSiaikpORClwMA
AJJA3Kd4Tp48GZ0NkaQjR46osbFR2dnZys7O1ooVK1RRUSGv16vDhw/r+9//viZMmKCysjJJ0pQpUzR37lzdddddWrNmjXp6elRVVaXbbruNO3gAAIAkyWGMMfF8Yfv27frmN7/5pfWLFy/W6tWrtWDBAu3Zs0ft7e3Kz8/XnDlz9KMf/Ui5ubnRvidOnFBVVZVef/11OZ1OVVRU6LnnnlNmZmafaggGg/J4PAoEApzuAQAgScTz9zvugGIDAgoAAMknnr/fvIsHAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKwT98sCAaC/wj1dOvzW/z5rH4fDoQlzviuHk/8/AUMZAQXAoDGRsALN+87eyeFUuLdbqenDBqcoAFbivygArBPp6Up0CQASjIACwDrhnlCiSwCQYAQUAJYxzKAAIKAAsA8zKAAIKACswwwKAAIKALsYKdxNQAGGOgIKAOuEeznFAwx1BBQAg8bhTFGmd8I5ehkFmvcPSj0A7EVAATBoHI4UZYwed85+n534eBCqAWAzAgqAweNwKCXNlegqACQBAgqAQeVMS090CQCSAAEFwKBxOBxypjKDAuDcCCgABlUKAQVAHxBQAAwih5xcgwKgDwgoAAaPg4ACoG8IKAAGVV9P8RhjLnAlAGxGQAEwaBwOhxzOPvyzY4xMJHLhCwJgLQIKAOsYYxThcffAkBZXQKmtrdXVV1+tkSNHKicnRwsWLFBTU1NMn66uLlVWVmrUqFHKzMxURUWFWltbY/o0NzervLxcGRkZysnJ0YMPPqje3t7zPxoAFwmjSG9PoosAkEBxBZS6ujpVVlbqvffe05YtW9TT06M5c+aos7Mz2uf+++/X66+/rpdffll1dXU6duyYbr755mh7OBxWeXm5uru79e677+rFF1/U2rVr9cgjjwzcUQFIasZEmEEBhjiHOY8r0Y4fP66cnBzV1dXpuuuuUyAQ0JgxY7Ru3TrdcsstkqQPP/xQU6ZMUX19vWbNmqVNmzbphhtu0LFjx5SbmytJWrNmjZYvX67jx48rPf3cT5kMBoPyeDwKBAJyu939LR9AAnS0fKQPf/nUWfukZXj0lfn3KmPU2EGqCsBgiOfv93ldgxIIBCRJ2dnZkqSGhgb19PSotLQ02mfy5MkqLCxUfX29JKm+vl7Tpk2LhhNJKisrUzAY1IEDB067n1AopGAwGLMAuHh9fg1Kd6LLAJBA/Q4okUhE9913n6699lpNnTpVkuT3+5Wenq6srKyYvrm5ufL7/dE+fx5OTrWfajud2tpaeTye6FJQUNDfsgEkAwIKMOT1O6BUVlZq//79Wr9+/UDWc1o1NTUKBALR5ejRoxd8nwASiYtkgaEutT9fqqqq0saNG7Vjxw6NHfs/54i9Xq+6u7vV3t4eM4vS2toqr9cb7bNr166Y7Z26y+dUny9yuVxyuXj6JHAxcDhT5Ux1nfUiWBMJq+czTuUCQ1lcMyjGGFVVVWnDhg3atm2bxo8fH9M+Y8YMpaWlaevWrdF1TU1Nam5uls/nkyT5fD7t27dPbW1t0T5btmyR2+1WUVHR+RwLgCSQnpkl99gpZ+0T7v5MgeZ9g1QRABvFNYNSWVmpdevW6bXXXtPIkSOj14x4PB4NHz5cHo9HS5YsUXV1tbKzs+V2u3XPPffI5/Np1qxZkqQ5c+aoqKhId9xxh1auXCm/36+HH35YlZWVzJIAQ4DDkSJnSlqiywBgubgCyurVqyVJ119/fcz6F154QXfeeack6emnn5bT6VRFRYVCoZDKysr0/PPPR/umpKRo48aNWrZsmXw+n0aMGKHFixfr8ccfP78jAZAcHE45CCgAzuG8noOSKDwHBUhevaFP9fGuDTp+sO6s/S4Z/zVNmHP3IFUFYDAM2nNQACBeDoeTUzwAzomAAmBQOZxOOVMJKADOjoACYHAxgwKgDwgoAAaVw9n3i2ST8BI5
AAOEgAJgkDnkSEk5Zy9jjERAAYYsAgqAQeVwOCQ5ztnPRMKKRHovfEEArERAAWAlY8IyYQIKMFQRUABYyUQIKMBQRkABYCVO8QBDGwEFgJU+n0HpSXQZABKEgALASiYSVoRTPMCQRUABYCUTiXANCjCEEVAAWIkZFGBoI6AAGHQjxozTsEvyz9onFGhVZ9uRQaoIgG0IKAAGXUracKWkuc7a5/MZFC6SBYYqAgqAQedMSZXDee7H3QMYuggoAAadg4AC4BwIKAAGncOZQkABcFYEFACD7vMZlNRElwHAYgQUAIPOmZIqRwozKADOjIACYNA5nKlycooHwFkQUAAMOoczRXL0IaAYI2PMhS8IgHUIKAAGncPhkMNx7n4m0iuZyIUvCIB1CCgArBUJ98hECCjAUERAAWCtSG+PjAknugwACUBAAWAtE+5lBgUYoggoAKz1+SkeZlCAoYiAAsBaJtwrw0WywJBEQAFgrUi4R+IUDzAkEVAAJESqK/Oc7+Pp+TSgcE/XIFUEwCYEFAAJMfLSyUpxjThrn862I+o++X8HqSIANokroNTW1urqq6/WyJEjlZOTowULFqipqSmmz/XXX///H8L0P8vdd98d06e5uVnl5eXKyMhQTk6OHnzwQfX29p7/0QBIGs7UdDn68rQ2AENSXK8TraurU2Vlpa6++mr19vbqBz/4gebMmaODBw9qxIj/+Z/QXXfdpccffzz6OSMjI/pzOBxWeXm5vF6v3n33XbW0tGjRokVKS0vTE088MQCHBCAZOFPTJQeTuABOL66Asnnz5pjPa9euVU5OjhoaGnTddddF12dkZMjr9Z52G7/+9a918OBBvfXWW8rNzdVVV12lH/3oR1q+fLkee+wxpaen9+MwACQbZ2oaMygAzui8/vsSCAQkSdnZ2THrX3rpJY0ePVpTp05VTU2NPv3002hbfX29pk2bptzc3Oi6srIyBYNBHThw4LT7CYVCCgaDMQuA5MYMCoCziWsG5c9FIhHdd999uvbaazV16tTo+m9/+9saN26c8vPztXfvXi1fvlxNTU165ZVXJEl+vz8mnEiKfvb7/afdV21trVasWNHfUgFYiBkUAGfT74BSWVmp/fv365133olZv3Tp0ujP06ZNU15enmbPnq3Dhw/riiuu6Ne+ampqVF1dHf0cDAZVUFDQv8IBWMGZwgwKgDPr178OVVVV2rhxo95++22NHTv2rH1LSkokSYcOHZIkeb1etba2xvQ59flM1624XC653e6YBUBy4y4eAGcTV0AxxqiqqkobNmzQtm3bNH78+HN+p7GxUZKUl5cnSfL5fNq3b5/a2tqifbZs2SK3262ioqJ4ygGQxJypaX2cQTEyxlzwegDYJa5TPJWVlVq3bp1ee+01jRw5MnrNiMfj0fDhw3X48GGtW7dO8+fP16hRo7R3717df//9uu6661RcXCxJmjNnjoqKinTHHXdo5cqV8vv9evjhh1VZWSmXyzXwRwjASg6HU32ZP4mEey54LQDsE9cMyurVqxUIBHT99dcrLy8vuvzsZz+TJKWnp+utt97SnDlzNHnyZD3wwAOqqKjQ66+/Ht1GSkqKNm7cqJSUFPl8Pv3d3/2dFi1aFPPcFAA4JdLbLTGDAgw5cc2gnGuataCgQHV1defczrhx4/TGG2/Es2sAQ1S4p1sSAQUYariEHoDVIr0hrkEBhiACCgCrRXqZQQGGIgIKAKtFekLkE2AIIqAAsFqkt1uGhAIMOQQUAFYL94S4iwcYgggoABJm9JRvnPNhbZ98VK9IT9cgVQTAFgQUAAmTNizznH1MuJcTPMAQREABkDDONBfv4wFwWgQUAAnjTOP1FgBOj4ACIGFSCCgAzoCAAiBhnKkEFACnR0ABkDApacOkPr3TGMBQQ0ABkDBcgwLgTAgoABImJW0YEygATouAAiBhHM6UPvUz4R7eaAwMMQQUANYL94QSXQKAQUZAAWC9SDePugeGGgIKAOsxgwIM
PQQUANYL87JAYMhJTXQBAJJXJBJRJBLp/wb6eOFrT+hThcPhfu/G4XAoJaVvF+QCsAMzKAD67Z/+6Z80fPjwfi8ZIzLU1XXu2ZF7q757Xvv527/920EYDQADiRkUAP0WiUTU29t7wffjSnOe137OZ/YFQGIwgwIgoX5/yC9J6gy79cfPpuijzhk68tlUtfeMjva5ZvKliSoPQIIwgwIgoRqaWjRp4lXa3/ENdUY8Cps0OdWrjJQOTcj4nfJc/0fFV+QmukwAg4yAAiChAqHh2h2Yrx4zLLouojSdDGdr/8nrlOboUqaaE1ghgETgFA+AhCq6bmVMOPlzvSZdu4PlCkUyBrkqAIlGQAGQYOd6WyBvEwSGIgIKAACwDgEFAABYh4ACIKHWv/i/5NTpn3HiUFhfG/lruZyfDXJVABItroCyevVqFRcXy+12y+12y+fzadOmTdH2rq4uVVZWatSoUcrMzFRFRYVaW1tjttHc3Kzy8nJlZGQoJydHDz744KA86AmAnT7t+G9d49moDGdATvVIMnKoV8OcHZqa+RvlpP9BUt8eiQ/g4hHXbcZjx47Vk08+qYkTJ8oYoxdffFE33XST9uzZoyuvvFL333+/fvWrX+nll1+Wx+NRVVWVbr75Zv32t7+V9PnTHMvLy+X1evXuu++qpaVFixYtUlpamp544okLcoAA7Bb8NKS6nTvU0btPrd2XqSuSqXRHl0anH1Ugza/3JfX2nsf7fgAkJYcxfXxb1xlkZ2frqaee0i233KIxY8Zo3bp1uuWWWyRJH374oaZMmaL6+nrNmjVLmzZt0g033KBjx44pN/fzBy+tWbNGy5cv1/Hjx5Went6nfQaDQXk8Ht155519/g6AgdfQ0KCGhoZEl3FO48aNU1lZWaLLAIa87u5urV27VoFAQG63+6x9+/2gtnA4rJdfflmdnZ3y+XxqaGhQT0+PSktLo30mT56swsLCaECpr6/XtGnTouFEksrKyrRs2TIdOHBAX/3qV0+7r1AopFAoFP0cDAYlSXfccYcyMzP7ewgAzpMxJikCSmFhoZYsWZLoMoAh7+TJk1q7dm2f+sYdUPbt2yefz6euri5lZmZqw4YNKioqUmNjo9LT05WVlRXTPzc3V37/5+/a8Pv9MeHkVPuptjOpra3VihUrvrR+5syZ50xgAC6cP78GzWaXXHKJrrnmmkSXAQx5pyYY+iLuu3gmTZqkxsZG7dy5U8uWLdPixYt18ODBeDcTl5qaGgUCgehy9OjRC7o/AACQWHHPoKSnp2vChAmSpBkzZmj37t169tlndeutt6q7u1vt7e0xsyitra3yer2SJK/Xq127dsVs79RdPqf6nI7L5ZLL5Yq3VAAAkKTO+zkokUhEoVBIM2bMUFpamrZu3Rpta2pqUnNzs3w+nyTJ5/Np3759amtri/bZsmWL3G63ioqKzrcUAABwkYhrBqWmpkbz5s1TYWGhOjo6tG7dOm3fvl1vvvmmPB6PlixZourqamVnZ8vtduuee+6Rz+fTrFmzJElz5sxRUVGR7rjjDq1cuVJ+v18PP/ywKisrmSEBAABRcQWUtrY2LVq0SC0tLfJ4PCouLtabb76pb33rW5Kkp59+Wk6nUxUVFQqFQiorK9Pzzz8f/X5KSoo2btyoZcuWyefzacSIEVq8eLEef/zxgT0qAACQ1OIKKD/5yU/O2j5s2DCtWrVKq1atOmOfcePG6Y033ohntwAAYIjhXTwAAMA6BBQAAGAdAgoAALAOAQUAAFin3+/iAYDJkydrwYIFiS7jnHjMPZB8zvttxolw6m3GfXkbIgAAsEM8f785xQMAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoA
ALAOAQUAAFgnroCyevVqFRcXy+12y+12y+fzadOmTdH266+/Xg6HI2a5++67Y7bR3Nys8vJyZWRkKCcnRw8++KB6e3sH5mgAAMBFITWezmPHjtWTTz6piRMnyhijF198UTfddJP27NmjK6+8UpJ011136fHHH49+JyMjI/pzOBxWeXm5vF6v3n33XbW0tGjRokVKS0vTE088MUCHBAAAkp3DGGPOZwPZ2dl66qmntGTJEl1//fW66qqr9Mwzz5y276ZNm3TDDTfo2LFjys3NlSStWbNGy5cv1/Hjx5Went6nfQaDQXk8HgUCAbnd7vMpHwAADJJ4/n73+xqUcDis9evXq7OzUz6fL7r+pZde0ujRozV16lTV1NTo008/jbbV19dr2rRp0XAiSWVlZQoGgzpw4MAZ9xUKhRQMBmMWAABw8YrrFI8k7du3Tz6fT11dXcrMzNSGDRtUVFQkSfr2t7+tcePGKT8/X3v37tXy5cvV1NSkV155RZLk9/tjwomk6Ge/33/GfdbW1mrFihXxlgoAAJJU3AFl0qRJamxsVCAQ0C9+8QstXrxYdXV1Kioq0tKlS6P9pk2bpry8PM2ePVuHDx/WFVdc0e8ia2pqVF1dHf0cDAZVUFDQ7+0BAAC7xX2KJz09XRMmTNCMGTNUW1ur6dOn69lnnz1t35KSEknSoUOHJEler1etra0xfU599nq9Z9yny+WK3jl0agEAABev834OSiQSUSgUOm1bY2OjJCkvL0+S5PP5tG/fPrW1tUX7bNmyRW63O3qaCAAAIK5TPDU1NZo3b54KCwvV0dGhdevWafv27XrzzTd1+PBhrVu3TvPnz9eoUaO0d+9e3X///bruuutUXFwsSZozZ46Kiop0xx13aOXKlfL7/Xr44YdVWVkpl8t1QQ4QAAAkn7gCSltbmxYtWqSWlhZ5PB4VFxfrzTff1Le+9S0dPXpUb731lp555hl1dnaqoKBAFRUVevjhh6PfT0lJ0caNG7Vs2TL5fD6NGDFCixcvjnluCgAAwHk/ByUReA4KAADJZ1CegwIAAHChEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOukJrqA/jDGSJKCwWCCKwEAAH116u/2qb/jZ5OUAaWjo0OSVFBQkOBKAABAvDo6OuTxeM7ax2H6EmMsE4lE1NTUpKKiIh09elRutzvRJSWtYDCogoICxnEAMJYDh7EcGIzjwGEsB4YxRh0dHcrPz5fTefarTJJyBsXpdOrSSy+VJLndbn5ZBgDjOHAYy4HDWA4MxnHgMJbn71wzJ6dwkSwAALAOAQUAAFgnaQOKy+XSo48+KpfLlehSkhrjOHAYy4HDWA4MxnHgMJaDLykvkgUAABe3pJ1BAQAAFy8CCgAAsA4BBQAAWIeAAgAArJOUAWXVqlW67LLLNGzYMJWUlGjXrl2JLsk6O3bs0I033qj8/Hw5HA69+uqrMe3GGD3yyCPKy8vT8OHDVVpaqo8++iimz4kTJ7Rw4UK53W5lZWVpyZIlOnny5CAeReLV1tbq6quv1siRI5WTk6MFCxaoqakppk9XV5cqKys1atQoZWZmqqKiQq2trTF9mpubVV5eroyMDOXk5OjBBx9Ub2/vYB5KQq1evVrFxcXRh1z5fD5t2rQp2s4Y9t+TTz4ph8Oh++67L7qO8eybxx57TA6HI2aZPHlytJ1xTDCTZNavX2/S09PNf/zHf5gDBw6Yu+66y2RlZZnW1tZEl2aVN954w/zDP/yDeeWVV4wks2HDhpj2J5980ng8HvPq
q6+a3//+9+av//qvzfjx481nn30W7TN37lwzffp0895775nf/OY3ZsKECeb2228f5CNJrLKyMvPCCy+Y/fv3m8bGRjN//nxTWFhoTp48Ge1z9913m4KCArN161bz/vvvm1mzZpm/+Iu/iLb39vaaqVOnmtLSUrNnzx7zxhtvmNGjR5uamppEHFJC/PKXvzS/+tWvzH/913+ZpqYm84Mf/MCkpaWZ/fv3G2MYw/7atWuXueyyy0xxcbG59957o+sZz7559NFHzZVXXmlaWlqiy/Hjx6PtjGNiJV1Aueaaa0xlZWX0czgcNvn5+aa2tjaBVdntiwElEokYr9drnnrqqei69vZ243K5zE9/+lNjjDEHDx40kszu3bujfTZt2mQcDof57//+70Gr3TZtbW1GkqmrqzPGfD5uaWlp5uWXX472+eCDD4wkU19fb4z5PCw6nU7j9/ujfVavXm3cbrcJhUKDewAWueSSS8y///u/M4b91NHRYSZOnGi2bNli/vIv/zIaUBjPvnv00UfN9OnTT9vGOCZeUp3i6e7uVkNDg0pLS6PrnE6nSktLVV9fn8DKksuRI0fk9/tjxtHj8aikpCQ6jvX19crKytLMmTOjfUpLS+V0OrVz585Br9kWgUBAkpSdnS1JamhoUE9PT8xYTp48WYWFhTFjOW3aNOXm5kb7lJWVKRgM6sCBA4NYvR3C4bDWr1+vzs5O+Xw+xrCfKisrVV5eHjNuEr+T8froo4+Un5+vyy+/XAsXLlRzc7MkxtEGSfWywD/96U8Kh8MxvwySlJubqw8//DBBVSUfv98vSacdx1Ntfr9fOTk5Me2pqanKzs6O9hlqIpGI7rvvPl177bWaOnWqpM/HKT09XVlZWTF9vziWpxvrU21Dxb59++Tz+dTV1aXMzExt2LBBRUVFamxsZAzjtH79ev3ud7/T7t27v9TG72TflZSUaO3atZo0aZJaWlq0YsUKfeMb39D+/fsZRwskVUABEqmyslL79+/XO++8k+hSktKkSZPU2NioQCCgX/ziF1q8eLHq6uoSXVbSOXr0qO69915t2bJFw4YNS3Q5SW3evHnRn4uLi1VSUqJx48bp5z//uYYPH57AyiAl2V08o0ePVkpKypeuom5tbZXX601QVcnn1FidbRy9Xq/a2tpi2nt7e3XixIkhOdZVVVXauHGj3n77bY0dOza63uv1qru7W+3t7TH9vziWpxvrU21DRXp6uiZMmKAZM2aotrZW06dP17PPPssYxqmhoUFtbW362te+ptTUVKWmpqqurk7PPfecUlNTlZuby3j2U1ZWlr7yla/o0KFD/F5aIKkCSnp6umbMmKGtW7dG10UiEW3dulU+ny+BlSWX8ePHy+v1xoxjMBjUzp07o+Po8/nU3t6uhoaGaJ9t27YpEomopKRk0GtOFGOMqqqqtGHDBm3btk3jx4+PaZ8xY4bS0tJixrKpqUnNzc0xY7lv376YwLdlyxa53W4VFRUNzoFYKBKJKBQKMYZxmj17tvbt26fGxsboMnPmTC1cuDD6M+PZPydPntThw4eVl5fH76UNEn2VbrzWr19vXC6XWbt2rTl48KBZunSpycrKirmKGp9f4b9nzx6zZ88eI8n867/+q9mzZ4/54x//aIz5/DbjrKws89prr5m9e/eam2666bS3GX/1q181O3fuNO+8846ZOHHikLvNeNmyZcbj8Zjt27fH3Ir46aefRvvcfffdprCw0Gzbts28//77xufzGZ/PF20/dSvinDlzTGNjo9m8ebMZM2bMkLoV8aGHHjJ1dXXmyJEjZu/eveahhx4yDofD/PrXvzbGMIbn68/v4jGG8eyrBx54wGzfvt0cOXLE/Pa3vzWlpaVm9OjRpq2tzRjDOCZa0gUUY4z58Y9/bAoLC016erq55pprzHvvvZfokqzz9ttvG0lfWhYvXmyM+fxW4x/+8IcmNzfXuFwuM3v2bNPU1BSzjU8++cTcfvvtJjMz07jdbvOd
73zHdHR0JOBoEud0YyjJvPDCC9E+n332mfnud79rLrnkEpORkWH+5m/+xrS0tMRs5w9/+IOZN2+eGT58uBk9erR54IEHTE9PzyAfTeL8/d//vRk3bpxJT083Y8aMMbNnz46GE2MYw/P1xYDCePbNrbfeavLy8kx6erq59NJLza233moOHToUbWccE8thjDGJmbsBAAA4vaS6BgUAAAwNBBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWOf/AbfZ/uUjpPReAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# Draw the current frame of the game\n",
    "def show():\n",
    "    frame = env.render()  # an RGB array, since env uses render_mode='rgb_array'\n",
    "    plt.imshow(frame)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0.3000, 0.7000],\n",
       "         [0.4817, 0.5183]], grad_fn=<SoftmaxBackward0>),\n",
       " tensor([[0.3353],\n",
       "         [0.1960]], grad_fn=<AddmmBackward0>))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Models: the policy network (actor) maps a 4-dim state to probabilities\n",
    "# over the 2 actions; the value network (critic) maps a state to one scalar.\n",
    "def _make_mlp(out_features, softmax=False):\n",
    "    layers = [\n",
    "        torch.nn.Linear(4, 128),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.Linear(128, out_features),\n",
    "    ]\n",
    "    if softmax:\n",
    "        layers.append(torch.nn.Softmax(dim=1))\n",
    "    return torch.nn.Sequential(*layers)\n",
    "\n",
    "\n",
    "model = _make_mlp(2, softmax=True)\n",
    "model_td = _make_mlp(1)\n",
    "\n",
    "model(torch.randn(2, 4)), model_td(torch.randn(2, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "# Sample an action for one state from the current policy\n",
    "def get_action(state):\n",
    "    # 4 observation values -> [1, 4]\n",
    "    state = torch.FloatTensor(state).reshape(1, 4)\n",
    "    # Sampling needs no gradients, so skip building an autograd graph\n",
    "    with torch.no_grad():\n",
    "        # [1, 4] -> [1, 2]\n",
    "        prob = model(state)\n",
    "\n",
    "    # Pick action index i with probability prob[0][i]\n",
    "    action = random.choices(range(prob.shape[1]), weights=prob[0].tolist(), k=1)[0]\n",
    "\n",
    "    return action\n",
    "\n",
    "\n",
    "get_action([1, 2, 3, 4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1853/2726165283.py:31: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ../torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  states = torch.FloatTensor(states).reshape(-1, 4)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[ 2.4113e-02,  4.5527e-03,  3.7928e-02, -2.6878e-02],\n",
       "         [ 2.4204e-02, -1.9109e-01,  3.7390e-02,  2.7753e-01],\n",
       "         [ 2.0382e-02,  3.4771e-03,  4.2941e-02, -3.1339e-03],\n",
       "         [ 2.0451e-02,  1.9796e-01,  4.2878e-02, -2.8197e-01],\n",
       "         [ 2.4410e-02,  2.2513e-03,  3.7239e-02,  2.3927e-02],\n",
       "         [ 2.4455e-02,  1.9682e-01,  3.7717e-02, -2.5678e-01],\n",
       "         [ 2.8392e-02,  1.1804e-03,  3.2582e-02,  4.7559e-02],\n",
       "         [ 2.8415e-02,  1.9582e-01,  3.3533e-02, -2.3467e-01],\n",
       "         [ 3.2332e-02,  2.3573e-04,  2.8839e-02,  6.8400e-02],\n",
       "         [ 3.2337e-02, -1.9529e-01,  3.0207e-02,  3.7004e-01],\n",
       "         [ 2.8431e-02, -3.9083e-01,  3.7608e-02,  6.7209e-01],\n",
       "         [ 2.0614e-02, -5.8645e-01,  5.1050e-02,  9.7638e-01],\n",
       "         [ 8.8853e-03, -7.8222e-01,  7.0578e-02,  1.2846e+00],\n",
       "         [-6.7590e-03, -9.7816e-01,  9.6271e-02,  1.5986e+00],\n",
       "         [-2.6322e-02, -7.8430e-01,  1.2824e-01,  1.3374e+00],\n",
       "         [-4.2008e-02, -9.8079e-01,  1.5499e-01,  1.6673e+00],\n",
       "         [-6.1624e-02, -7.8777e-01,  1.8834e-01,  1.4266e+00]]),\n",
       " tensor([[1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.]]),\n",
       " tensor([[0],\n",
       "         [1],\n",
       "         [1],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0]]),\n",
       " tensor([[ 2.4204e-02, -1.9109e-01,  3.7390e-02,  2.7753e-01],\n",
       "         [ 2.0382e-02,  3.4771e-03,  4.2941e-02, -3.1339e-03],\n",
       "         [ 2.0451e-02,  1.9796e-01,  4.2878e-02, -2.8197e-01],\n",
       "         [ 2.4410e-02,  2.2513e-03,  3.7239e-02,  2.3927e-02],\n",
       "         [ 2.4455e-02,  1.9682e-01,  3.7717e-02, -2.5678e-01],\n",
       "         [ 2.8392e-02,  1.1804e-03,  3.2582e-02,  4.7559e-02],\n",
       "         [ 2.8415e-02,  1.9582e-01,  3.3533e-02, -2.3467e-01],\n",
       "         [ 3.2332e-02,  2.3573e-04,  2.8839e-02,  6.8400e-02],\n",
       "         [ 3.2337e-02, -1.9529e-01,  3.0207e-02,  3.7004e-01],\n",
       "         [ 2.8431e-02, -3.9083e-01,  3.7608e-02,  6.7209e-01],\n",
       "         [ 2.0614e-02, -5.8645e-01,  5.1050e-02,  9.7638e-01],\n",
       "         [ 8.8853e-03, -7.8222e-01,  7.0578e-02,  1.2846e+00],\n",
       "         [-6.7590e-03, -9.7816e-01,  9.6271e-02,  1.5986e+00],\n",
       "         [-2.6322e-02, -7.8430e-01,  1.2824e-01,  1.3374e+00],\n",
       "         [-4.2008e-02, -9.8079e-01,  1.5499e-01,  1.6673e+00],\n",
       "         [-6.1624e-02, -7.8777e-01,  1.8834e-01,  1.4266e+00],\n",
       "         [-7.7380e-02, -9.8465e-01,  2.1687e-01,  1.7718e+00]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1]]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "# Play one episode with the current policy and collect its transitions\n",
    "def get_data():\n",
    "    # Returns batched tensors\n",
    "    #   states [b, 4], rewards [b, 1], actions [b, 1], next_states [b, 4], overs [b, 1]\n",
    "    # overs is 1 where the episode ended at that step, except for steps whose\n",
    "    # info reports 'TimeLimit.truncated': a time-out is not a terminal state,\n",
    "    # so it is stored as 0 and a TD target scaled by (1 - overs) still\n",
    "    # bootstraps from next_state.\n",
    "    states = []\n",
    "    rewards = []\n",
    "    actions = []\n",
    "    next_states = []\n",
    "    overs = []\n",
    "\n",
    "    state = env.reset()\n",
    "\n",
    "    # play until the episode ends\n",
    "    over = False\n",
    "    while not over:\n",
    "        action = get_action(state)\n",
    "        next_state, reward, over, info = env.step(action)\n",
    "        truncated = bool(info.get('TimeLimit.truncated', False))\n",
    "\n",
    "        states.append(state)\n",
    "        rewards.append(reward)\n",
    "        actions.append(action)\n",
    "        next_states.append(next_state)\n",
    "        overs.append(over and not truncated)\n",
    "\n",
    "        state = next_state\n",
    "\n",
    "    # Stack the per-step arrays into one ndarray first: torch warns that\n",
    "    # building a tensor from a list of ndarrays is extremely slow.\n",
    "    states = torch.FloatTensor(np.array(states)).reshape(-1, 4)\n",
    "    rewards = torch.FloatTensor(rewards).reshape(-1, 1)\n",
    "    actions = torch.LongTensor(actions).reshape(-1, 1)\n",
    "    next_states = torch.FloatTensor(np.array(next_states)).reshape(-1, 4)\n",
    "    overs = torch.LongTensor(overs).reshape(-1, 1)\n",
    "\n",
    "    return states, rewards, actions, next_states, overs\n",
    "\n",
    "\n",
    "get_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "11.0"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "# Play one episode with the current policy and return its total reward\n",
    "# (higher is better). With play=True, about 20% of the frames are drawn\n",
    "# (random frame skipping) to animate the episode.\n",
    "def test(play):\n",
    "    total = 0\n",
    "    observation = env.reset()\n",
    "\n",
    "    finished = False\n",
    "    while not finished:\n",
    "        observation, reward, finished, _ = env.step(get_action(observation))\n",
    "        total += reward\n",
    "\n",
    "        if not play:\n",
    "            continue\n",
    "        if random.random() < 0.2:  # frame skip\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return total\n",
    "\n",
    "\n",
    "test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[8.090483997483998, 8.690100963999999, 8.260044, 6.724, 4.0]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Advantages via GAE: walk the TD errors backwards, accumulating\n",
    "#   A_t = delta_t + gamma * lam * A_{t+1}\n",
    "def get_advantages(deltas, gamma=0.98, lam=0.95):\n",
    "    # deltas: per-step TD errors in time order (any iterable of numbers)\n",
    "    # gamma: discount factor; lam: GAE lambda\n",
    "    # Returns a list of advantages aligned with deltas.\n",
    "    decay = gamma * lam\n",
    "    advantages = []\n",
    "\n",
    "    s = 0.0\n",
    "    for delta in reversed(list(deltas)):\n",
    "        s = decay * s + delta\n",
    "        advantages.append(s)\n",
    "\n",
    "    # accumulated back-to-front, so restore time order\n",
    "    advantages.reverse()\n",
    "    return advantages\n",
    "\n",
    "\n",
    "get_advantages(range(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "executionInfo": {
     "elapsed": 8251,
     "status": "ok",
     "timestamp": 1650011468229,
     "user": {
      "displayName": "Sam Lu",
      "userId": "15789059763790170725"
     },
     "user_tz": -480
    },
    "id": "BQXVYW2T_DcQ",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 16.6\n",
      "50 175.3\n",
      "100 200.0\n",
      "150 191.2\n",
      "200 159.0\n",
      "250 200.0\n",
      "300 200.0\n",
      "350 200.0\n",
      "400 145.0\n",
      "450 200.0\n"
     ]
    }
   ],
   "source": [
    "def train(epochs=500, update_steps=10, clip_eps=0.2, eval_every=50):\n",
    "    # PPO training loop. Each epoch plays one episode with get_data(), then\n",
    "    # runs update_steps gradient steps on that episode for both networks:\n",
    "    # the policy (model) with the clipped surrogate objective and the value\n",
    "    # network (model_td) towards one-step TD targets. Every eval_every epochs\n",
    "    # it prints the epoch and the mean reward over 10 test() episodes.\n",
    "    #   epochs: number of episodes to collect and train on\n",
    "    #   update_steps: optimisation steps per collected episode\n",
    "    #   clip_eps: ratios are clipped to [1 - clip_eps, 1 + clip_eps]\n",
    "    #   eval_every: evaluation interval, in epochs\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "    optimizer_td = torch.optim.Adam(model_td.parameters(), lr=1e-2)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        # states [b, 4], rewards [b, 1], actions [b, 1],\n",
    "        # next_states [b, 4], overs [b, 1]\n",
    "        states, rewards, actions, next_states, overs = get_data()\n",
    "\n",
    "        # Everything below stays fixed during this epoch's updates, so it is\n",
    "        # computed once without building an autograd graph.\n",
    "        with torch.no_grad():\n",
    "            # [b, 4] -> [b, 1]\n",
    "            values = model_td(states)\n",
    "\n",
    "            # TD target r + 0.98 * V(s'), without bootstrapping where overs == 1\n",
    "            targets = model_td(next_states) * 0.98\n",
    "            targets *= (1 - overs)\n",
    "            targets += rewards\n",
    "\n",
    "            # probability of each taken action under the pre-update policy\n",
    "            # [b, 2] -> [b, 1]\n",
    "            old_probs = model(states).gather(dim=1, index=actions)\n",
    "\n",
    "        # The advantages play the role of the return in vanilla policy\n",
    "        # gradient, but are built from TD errors (target - value).\n",
    "        deltas = (targets - values).squeeze(dim=1).tolist()\n",
    "        advantages = torch.FloatTensor(get_advantages(deltas)).reshape(-1, 1)\n",
    "\n",
    "        for _ in range(update_steps):\n",
    "            # probabilities of the taken actions under the current policy\n",
    "            # [b, 4] -> [b, 2] -> [b, 1]\n",
    "            new_probs = model(states).gather(dim=1, index=actions)\n",
    "\n",
    "            # probability ratio: [b, 1] / [b, 1] -> [b, 1]\n",
    "            ratios = new_probs / old_probs\n",
    "\n",
    "            # clipped surrogate: keep the smaller of the unclipped and\n",
    "            # clipped objectives, negated so it can be minimised\n",
    "            surr1 = ratios * advantages\n",
    "            surr2 = torch.clamp(ratios, 1 - clip_eps, 1 + clip_eps) * advantages\n",
    "            loss = (-torch.min(surr1, surr2)).mean()\n",
    "\n",
    "            # value-network regression towards the fixed TD targets\n",
    "            values = model_td(states)\n",
    "            loss_td = loss_fn(values, targets)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            optimizer_td.zero_grad()\n",
    "            loss_td.backward()\n",
    "            optimizer_td.step()\n",
    "\n",
    "        if epoch % eval_every == 0:\n",
    "            test_result = sum([test(play=False) for _ in range(10)]) / 10\n",
    "            print(epoch, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAF7CAYAAAD4/3BBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAoIUlEQVR4nO3df3RU9Z3/8dfk10gIM2mAZJKSIAoFIgRdwDC1UlpSAkRWajyrloXY5cCRTTyFWIrpUhW7x7i4Z/3RVfiju2LPkdLSI7pSwcYgYdWAmJLll6bCoQ2WTEJlMwNoAsl8vn/4ZbqjCExIMp+ZPB/n3HMy9/OZO+/7OXjn5b2fe8dhjDECAACwSEK0CwAAAPg8AgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsE5UA8qzzz6ra6+9Vtdcc40KCwv17rvvRrMcAABgiagFlF/96leqrKzUww8/rN///veaNGmSiouL1dbWFq2SAACAJRzR+rHAwsJCTZ06Vf/+7/8uSQoGg8rNzdX999+vBx98MBolAQAASyRF40PPnTunhoYGVVVVhdYlJCSoqKhI9fX1X+jf2dmpzs7O0OtgMKhTp05p6NChcjgc/VIzAAC4OsYYnT59Wjk5OUpIuPRFnKgElL/85S/q7u5WVlZW2PqsrCx98MEHX+hfXV2tNWvW9Fd5AACgDx0/flwjRoy4ZJ+oBJRIVVVVqbKyMvTa7/crLy9Px48fl8vlimJlAADgSgUCAeXm5mrIkCGX7RuVgDJs2DAlJiaqtbU1bH1ra6s8Hs8X+judTjmdzi+sd7lcBBQAAGLMlUzPiMpdPCkpKZo8ebJqa2tD64LBoGpra+X1eqNREgAAsEjULvFUVlaqrKxMU6ZM0c0336ynnnpKZ8+e1fe///1olQQAACwRtYBy11136eTJk3rooYfk8/l04403avv27V+YOAsAAAaeqD0H5WoEAgG53W75/X7moAAAECMi+f7mt3gAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKzT6wHlkUcekcPhCFvGjRsXau/o6FB5ebmGDh2qtLQ0lZaWqrW1tbfLAAAAMaxPzqDccMMNamlpCS1vvfVWqG3FihV69dVXtXnzZtXV1enEiRO64447+qIMAAAQo5L6ZKNJSfJ4PF9Y7/f79R//8R/auHGjvv3tb0uSnn/+eY0fP167d+/WtGnT+qIcAAAQY/rkDMqHH36onJwcXXfddVqwYIGam5slSQ0NDTp//ryKiopCfceNG6e8vDzV19d/6fY6OzsVCATCFgAAEL96PaAUFhZqw4YN2r59u9atW6djx47p1ltv1enTp+Xz+ZSSkqL09PSw92RlZcnn833pNqurq+V2u0NLbm5ub5cNAAAs0uuXeObMmRP6u6CgQIWFhRo5cqR+/etfa9CgQT3aZlVVlSorK0OvA4EAIQUAgDjW57cZp6en62tf+5qOHDkij8ejc+fOqb29PaxPa2vrReesXOB0OuVyucIWAAAQv/o8oJw5c0ZHjx5Vdna2Jk+erOTkZNXW1obam5qa1NzcLK/X29elAACAGNHrl3h++MMfat68eRo5cqROnDihhx9+WImJibrnnnvkdru1ePFiVVZWKiMjQy6XS/fff7+8Xi938AAAgJBeDygfffSR7rnnHn388ccaPny4vvGNb2j37t0aPny4JOnJJ59UQkKCSktL1dnZqeLiYj333HO9XQYAAIhhDmOMiXYRkQoEAnK7
3fL7/cxHAQAgRkTy/c1v8QAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArBNxQNm1a5fmzZunnJwcORwOvfzyy2Htxhg99NBDys7O1qBBg1RUVKQPP/wwrM+pU6e0YMECuVwupaena/HixTpz5sxV7QgAAIgfEQeUs2fPatKkSXr22Wcv2r527Vo988wzWr9+vfbs2aPBgweruLhYHR0doT4LFizQoUOHVFNTo61bt2rXrl1aunRpz/cCAADEFYcxxvT4zQ6HtmzZovnz50v67OxJTk6OHnjgAf3whz+UJPn9fmVlZWnDhg26++679f777ys/P1979+7VlClTJEnbt2/X3Llz9dFHHyknJ+eynxsIBOR2u+X3++VyuXpaPgAA6EeRfH/36hyUY8eOyefzqaioKLTO7XarsLBQ9fX1kqT6+nqlp6eHwokkFRUVKSEhQXv27Lnodjs7OxUIBMIWAAAQv3o1oPh8PklSVlZW2PqsrKxQm8/nU2ZmZlh7UlKSMjIyQn0+r7q6Wm63O7Tk5ub2ZtkAAMAyMXEXT1VVlfx+f2g5fvx4tEsCAAB9qFcDisfjkSS1traGrW9tbQ21eTwetbW1hbV3dXXp1KlToT6f53Q65XK5whYAABC/ejWgjBo1Sh6PR7W1taF1gUBAe/bskdfrlSR5vV61t7eroaEh1GfHjh0KBoMqLCzszXIAAECMSor0DWfOnNGRI0dCr48dO6bGxkZlZGQoLy9Py5cv1z//8z9rzJgxGjVqlH7yk58oJycndKfP+PHjNXv2bC1ZskTr16/X+fPnVVFRobvvvvuK7uABAADxL+KA8t577+lb3/pW6HVlZaUkqaysTBs2bNCPfvQjnT17VkuXLlV7e7u+8Y1vaPv27brmmmtC73nxxRdVUVGhmTNnKiEhQaWlpXrmmWd6YXcAAEA8uKrnoEQLz0EBACD2RO05KAAAAL2BgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoRB5Rdu3Zp3rx5ysnJkcPh0MsvvxzWfu+998rhcIQts2fPDutz6tQpLViwQC6XS+np6Vq8eLHOnDlzVTsCAADiR8QB5ezZs5o0aZKeffbZL+0ze/ZstbS0hJZf/vKXYe0LFizQoUOHVFNTo61bt2rXrl1aunRp5NUDAIC4lBTpG+bMmaM5c+Zcso/T6ZTH47lo2/vvv6/t27dr7969mjJliiTpZz/7mebOnat//dd/VU5OTqQlAQCAONMnc1B27typzMxMjR07VsuWLdPHH38caquvr1d6enoonEhSUVGREhIStGfPnotur7OzU4FAIGwBAADxq9cDyuzZs/WLX/xCtbW1+pd/+RfV1dVpzpw56u7uliT5fD5lZmaGvScpKUkZGRny+XwX3WZ1dbXcbndoyc3N7e2yAQCARSK+xHM5d999d+jviRMnqqCgQNdff7127typmTNn9mibVVVVqqysDL0OBAKEFAAA4lif32Z83XXXadiwYTpy5IgkyePxqK2tLaxPV1eXTp069aXzVpxOp1wuV9gCAADiV58HlI8++kgff/yxsrOzJUler1ft7e1qaGgI9dmxY4eCwaAKCwv7uhwAABAD
Ir7Ec+bMmdDZEEk6duyYGhsblZGRoYyMDK1Zs0alpaXyeDw6evSofvSjH2n06NEqLi6WJI0fP16zZ8/WkiVLtH79ep0/f14VFRW6++67uYMHAABIkhzGGBPJG3bu3KlvfetbX1hfVlamdevWaf78+dq3b5/a29uVk5OjWbNm6ac//amysrJCfU+dOqWKigq9+uqrSkhIUGlpqZ555hmlpaVdUQ2BQEBut1t+v5/LPQAAxIhIvr8jDig2IKAAABB7Ivn+5rd4AACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6Ef9YIAD0JRMM6sjvntPlfoXj+qIlSky+pp+qAtDfCCgALGPU3nxQMsFL9gp2dSkhycjhcPRTXQD6E5d4AMQkY7qjXQKAPkRAARCTTJCAAsQzAgqA2BS89CUgALGNgAIgJhkCChDXCCgAYhJzUID4RkABEJOYgwLENwIKgJhEQAHiGwEFQExiDgoQ3wgoAGISc1CA+EZAARCTOIMCxDcCCoCYxBwUIL4RUADEJC7xAPGNgAIgJnGJB4hvBBQAsYkzKEBcI6AAiEnMQQHiGwEFQEwy3VziAeIZAQVATGKSLBDfCCgAYhKTZIH4FlFAqa6u1tSpUzVkyBBlZmZq/vz5ampqCuvT0dGh8vJyDR06VGlpaSotLVVra2tYn+bmZpWUlCg1NVWZmZlauXKlurq6rn5vAAwYzEEB4ltEAaWurk7l5eXavXu3ampqdP78ec2aNUtnz54N9VmxYoVeffVVbd68WXV1dTpx4oTuuOOOUHt3d7dKSkp07tw5vfPOO3rhhRe0YcMGPfTQQ723VwDin+EMChDPHMYY09M3nzx5UpmZmaqrq9P06dPl9/s1fPhwbdy4UXfeeack6YMPPtD48eNVX1+vadOmadu2bbrtttt04sQJZWVlSZLWr1+vVatW6eTJk0pJSbns5wYCAbndbvn9frlcrp6WD8BCJtit935eftkAkvf1u5Q54dtyOBz9VBmAqxXJ9/dVzUHx+/2SpIyMDElSQ0ODzp8/r6KiolCfcePGKS8vT/X19ZKk+vp6TZw4MRROJKm4uFiBQECHDh266Od0dnYqEAiELQAGtiBzUIC41uOAEgwGtXz5ct1yyy2aMGGCJMnn8yklJUXp6elhfbOysuTz+UJ9/m84udB+oe1iqqur5Xa7Q0tubm5PywYQJ5iDAsS3HgeU8vJyHTx4UJs2berNei6qqqpKfr8/tBw/frzPPxOA3QgoQHzrUUCpqKjQ1q1b9eabb2rEiBGh9R6PR+fOnVN7e3tY/9bWVnk8nlCfz9/Vc+H1hT6f53Q65XK5whYA8SvJmXrZPl0dp/uhEgDRElFAMcaooqJCW7Zs0Y4dOzRq1Kiw9smTJys5OVm1tbWhdU1NTWpubpbX65Ukeb1eHThwQG1tbaE+NTU1crlcys/Pv5p9ARAPHA4N/drXL9vt4z/U90MxAKIlKZLO5eXl2rhxo1555RUNGTIkNGfE7XZr0KBBcrvdWrx4sSorK5WRkSGXy6X7779fXq9X06ZNkyTNmjVL+fn5WrhwodauXSufz6fVq1ervLxcTqez9/cQQMxxJPAMSWCgiyigrFu3TpI0Y8aMsPXPP/+87r33XknSk08+qYSEBJWWlqqzs1PFxcV67rnnQn0TExO1detWLVu2TF6vV4MHD1ZZWZkeffTRq9sTAHHDkZAY7RIARNlVPQclWngOChC/jAnqRMNvdaLh1Uv2S0wZpJvufYrnoAAxpN+egwIAvc/BGRQABBQA9mEOCgCOAgCs43BwBgUY6AgoAKzDJR4ABBQA1uESDwCOAgDswxkUYMAjoACwjsPBoQkY6DgKALAOl3gAcBQAYB0myQIgoACwD7cZAwMeAQWAdbjEA4CjAADrMEkWAEcBANZhDgoAAgoA6xBQABBQAFiHOSgAOAoAsA5zUABwFABgHS7x
ACCgALAOl3gAcBQAYB0HD2oDBjwCCgD7cAYFGPA4CgCwDnNQABBQAFjF4XBwFw8AAgqAGGWMjAlGuwoAfYSAAiB2BQkoQLwioACISUbiDAoQxwgoAGKWCXZHuwQAfYSAAiBmcQYFiF8EFAAxyzAHBYhbBBQAMcpInEEB4hYBBUDMYg4KEL8IKABiFnNQgPgVUUCprq7W1KlTNWTIEGVmZmr+/PlqamoK6zNjxoz//yTIvy733XdfWJ/m5maVlJQoNTVVmZmZWrlypbq6uq5+bwAMHIYzKEA8S4qkc11dncrLyzV16lR1dXXpxz/+sWbNmqXDhw9r8ODBoX5LlizRo48+Gnqdmpoa+ru7u1slJSXyeDx655131NLSokWLFik5OVmPPfZYL+wSgIGCMyhA/IoooGzfvj3s9YYNG5SZmamGhgZNnz49tD41NVUej+ei2/jd736nw4cP64033lBWVpZuvPFG/fSnP9WqVav0yCOPKCUlpQe7AWAg4i4eIH5d1RwUv98vScrIyAhb/+KLL2rYsGGaMGGCqqqq9Mknn4Ta6uvrNXHiRGVlZYXWFRcXKxAI6NChQxf9nM7OTgUCgbAFwEBnJMMlHiBeRXQG5f8KBoNavny5brnlFk2YMCG0/nvf+55GjhypnJwc7d+/X6tWrVJTU5NeeuklSZLP5wsLJ5JCr30+30U/q7q6WmvWrOlpqQDiFGdQgPjV44BSXl6ugwcP6q233gpbv3Tp0tDfEydOVHZ2tmbOnKmjR4/q+uuv79FnVVVVqbKyMvQ6EAgoNze3Z4UDiBsEFCB+9egST0VFhbZu3ao333xTI0aMuGTfwsJCSdKRI0ckSR6PR62trWF9Lrz+snkrTqdTLpcrbAEAJskC8SuigGKMUUVFhbZs2aIdO3Zo1KhRl31PY2OjJCk7O1uS5PV6deDAAbW1tYX61NTUyOVyKT8/P5JyAAxw3GYMxK+ILvGUl5dr48aNeuWVVzRkyJDQnBG3261Bgwbp6NGj2rhxo+bOnauhQ4dq//79WrFihaZPn66CggJJ0qxZs5Sfn6+FCxdq7dq18vl8Wr16tcrLy+V0Ont/DwHEJ2MkLvEAcSuiMyjr1q2T3+/XjBkzlJ2dHVp+9atfSZJSUlL0xhtvaNasWRo3bpweeOABlZaW6tVXXw1tIzExUVu3blViYqK8Xq/+/u//XosWLQp7bgoAXI6RZLiLB4hbEZ1BMcZcsj03N1d1dXWX3c7IkSP12muvRfLRAPAFzEEB4he/xQMgZnEXDxC/CCgAYhcBBYhbBBQAMYtLPED8IqAAiFGG24yBOEZAARCzOIMCxC8CCgDrJDlTlX7tjZfsY7q79PEf6vunIAD9joACwD4OhxKSUi7bLdh1vh+KARANBBQAFnLI4eDwBAxkHAEA2CmBwxMwkHEEAGAfB2dQgIGOIwAA6zgkAgowwHEEAGAnAgowoHEEAGAhhxzMQQEGNI4AAOzjEAEFGOA4AgCwEJNkgYGOIwAAKxFQgIGNIwAA+zgcPAcFGOA4AgCwE2dQgAGNIwAA6/AcFABJ0S4AQHzq6urq8XuD3d0y5vL9jMxVfY4kJSYmyuFwXNU2APQ+AgqAPpGXl6eTJ0/26L3JSQn6u2/mq/y7Uy/Z7909e3Tj3w3q0WdcsGPHDt16661XtQ0AvY+AAqBPdHV19fzshknQ+St4rzFXfwbFXMmpGgD9joACwErdwb8Gh5PnRijQNVRBJWpQwmllpvxJKQnnolgdgL5GQAFgHWOMuoNBSVLT2alq6bxeHcHBMnIo2dGpjzrGaYp7W5SrBNCXmCYPwErdQenIJzfpj58W6NOgS0aJkhJ03gzS/3Zl6+32OxQ0HMKAeMV/3QCsYySd6BipDz+ZoqASL9rnk2639vjn9W9hAPoNAQWAfYwUDBp99kSUL+O4TDuAWEZAAWAdo7/OQQEwMBFQANjHhN/FA2DgIaAAsI6RNCzpjxo1qFHSxc+kOB2faIr7tf4s
C0A/iiigrFu3TgUFBXK5XHK5XPJ6vdq27a+3+nV0dKi8vFxDhw5VWlqaSktL1draGraN5uZmlZSUKDU1VZmZmVq5cuVVP2gJQPwxwW6NTX1Xede8rxTHJ3IoKMko0XFOaYmnNP0rv1Kyg2ehAPEqouegjBgxQo8//rjGjBkjY4xeeOEF3X777dq3b59uuOEGrVixQr/97W+1efNmud1uVVRU6I477tDbb78tSeru7lZJSYk8Ho/eeecdtbS0aNGiRUpOTtZjjz3WJzsIIDb9+S8BvfL2B5I+UEvnKLV3ZSloEjU40a8c5xFtTejQn08Gol0mgD7iMFf5nOeMjAw98cQTuvPOOzV8+HBt3LhRd955pyTpgw8+0Pjx41VfX69p06Zp27Ztuu2223TixAllZWVJktavX69Vq1bp5MmTSklJuaLPDAQCcrvduvfee6/4PQD61y9+8Qt1dHREu4zLmjdvnrKzs6NdBjAgnDt3Ths2bJDf75fL5bpk3x4/Sba7u1ubN2/W2bNn5fV61dDQoPPnz6uoqCjUZ9y4ccrLywsFlPr6ek2cODEUTiSpuLhYy5Yt06FDh3TTTTdd9LM6OzvV2dkZeh0IfPZ/TQsXLlRaWlpPdwFAH9q8eXNMBJSSkpIvPfYA6F1nzpzRhg0brqhvxAHlwIED8nq96ujoUFpamrZs2aL8/Hw1NjYqJSVF6enpYf2zsrLk8/kkST6fLyycXGi/0PZlqqurtWbNmi+snzJlymUTGIDoSEqKjV/SGD9+vG6++eZolwEMCBdOMFyJiO/iGTt2rBobG7Vnzx4tW7ZMZWVlOnz4cKSbiUhVVZX8fn9oOX78eJ9+HgAAiK6I/xcnJSVFo0ePliRNnjxZe/fu1dNPP6277rpL586dU3t7e9hZlNbWVnk8HkmSx+PRu+++G7a9C3f5XOhzMU6nU06nM9JSAQBAjLrq56AEg0F1dnZq8uTJSk5OVm1tbaitqalJzc3N8nq9kiSv16sDBw6ora0t1KempkYul0v5+flXWwoAAIgTEZ1Bqaqq0pw5c5SXl6fTp09r48aN2rlzp15//XW53W4tXrxYlZWVysjIkMvl0v333y+v16tp06ZJkmbNmqX8/HwtXLhQa9eulc/n0+rVq1VeXs4ZEgAAEBJRQGlra9OiRYvU0tIit9utgoICvf766/rOd74jSXryySeVkJCg0tJSdXZ2qri4WM8991zo/YmJidq6dauWLVsmr9erwYMHq6ysTI8++mjv7hUAAIhpV/0clGi48ByUK7mPGkB0ZGZm6uTJk9Eu47Lq6uo0ffr0aJcBDAiRfH/zWzwAAMA6BBQAAGAdAgoAALAOAQUAAFgnNp5FDSDmzJ07V36/P9plXNbQoUOjXQKAiyCgAOgTV/qDYABwMVziAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArBNRQFm3bp0KCgrkcrnkcrnk9Xq1bdu2UPuMGTPkcDjClvvuuy9sG83NzSopKVFqaqoyMzO1cuVKdXV19c7eAACAuJAUSecRI0bo8ccf15gxY2SM0QsvvKDbb79d+/bt0w033CBJWrJkiR599NHQe1JTU0N/d3d3q6SkRB6PR++8845aWlq0aNEiJScn67HHHuulXQIAALHOYYwxV7OBjIwMPfHEE1q8eLFmzJihG2+8UU899dRF+27btk233XabTpw4oaysLEnS+vXrtWrVKp08eVIpKSlX9JmBQEBut1t+v18ul+tqygcAAP0kku/vHs9B6e7u1qZNm3T27Fl5vd7Q+hdffFHDhg3ThAkTVFVVpU8++STUVl9fr4kTJ4bCiSQVFxcrEAjo0KFDX/pZnZ2d
CgQCYQsAAIhfEV3ikaQDBw7I6/Wqo6NDaWlp2rJli/Lz8yVJ3/ve9zRy5Ejl5ORo//79WrVqlZqamvTSSy9Jknw+X1g4kRR67fP5vvQzq6urtWbNmkhLBQAAMSrigDJ27Fg1NjbK7/frN7/5jcrKylRXV6f8/HwtXbo01G/ixInKzs7WzJkzdfToUV1//fU9LrKqqkqVlZWh14FAQLm5uT3eHgAAsFvEl3hSUlI0evRoTZ48WdXV1Zo0aZKefvrpi/YtLCyUJB05ckSS5PF41NraGtbnwmuPx/Oln+l0OkN3Dl1YAABA/Lrq56AEg0F1dnZetK2xsVGSlJ2dLUnyer06cOCA2traQn1qamrkcrlCl4kAAAAiusRTVVWlOXPmKC8vT6dPn9bGjRu1c+dOvf766zp69Kg2btyouXPnaujQodq/f79WrFih6dOnq6CgQJI0a9Ys5efna+HChVq7dq18Pp9Wr16t8vJyOZ3OPtlBAAAQeyIKKG1tbVq0aJFaWlrkdrtVUFCg119/Xd/5znd0/PhxvfHGG3rqqad09uxZ5ebmqrS0VKtXrw69PzExUVu3btWyZcvk9Xo1ePBglZWVhT03BQAA4KqfgxINPAcFAIDY0y/PQQEAAOgrBBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDpJ0S6gJ4wxkqRAIBDlSgAAwJW68L194Xv8UmIyoJw+fVqSlJubG+VKAABApE6fPi23233JPg5zJTHGMsFgUE1NTcrPz9fx48flcrmiXVLMCgQCys3NZRx7AWPZexjL3sE49h7GsncYY3T69Gnl5OQoIeHSs0xi8gxKQkKCvvrVr0qSXC4X/1h6AePYexjL3sNY9g7GsfcwllfvcmdOLmCSLAAAsA4BBQAAWCdmA4rT6dTDDz8sp9MZ7VJiGuPYexjL3sNY9g7Gsfcwlv0vJifJAgCA+BazZ1AAAED8IqAAAADrEFAAAIB1CCgAAMA6MRlQnn32WV177bW65pprVFhYqHfffTfaJVln165dmjdvnnJycuRwOPTyyy+HtRtj9NBDDyk7O1uDBg1SUVGRPvzww7A+p06d0oIFC+RyuZSenq7FixfrzJkz/bgX0VddXa2pU6dqyJAhyszM1Pz589XU1BTWp6OjQ+Xl5Ro6dKjS0tJUWlqq1tbWsD7Nzc0qKSlRamqqMjMztXLlSnV1dfXnrkTVunXrVFBQEHrIldfr1bZt20LtjGHPPf7443I4HFq+fHloHeN5ZR555BE5HI6wZdy4caF2xjHKTIzZtGmTSUlJMf/5n/9pDh06ZJYsWWLS09NNa2trtEuzymuvvWb+6Z/+ybz00ktGktmyZUtY++OPP27cbrd5+eWXzf/8z/+Yv/3bvzWjRo0yn376aajP7NmzzaRJk8zu3bvNf//3f5vRo0ebe+65p5/3JLqKi4vN888/bw4ePGgaGxvN3LlzTV5enjlz5kyoz3333Wdyc3NNbW2tee+998y0adPM17/+9VB7V1eXmTBhgikqKjL79u0zr732mhk2bJipqqqKxi5FxX/913+Z3/72t+YPf/iDaWpqMj/+8Y9NcnKyOXjwoDGGMeypd99911x77bWmoKDA/OAHPwitZzyvzMMPP2xuuOEG09LSElpOnjwZamccoyvmAsrNN99sysvLQ6+7u7tNTk6Oqa6ujmJVdvt8QAkGg8bj8ZgnnngitK69vd04nU7zy1/+0hhjzOHDh40ks3fv3lCfbdu2GYfDYf785z/3W+22aWtrM5JMXV2dMeazcUtO
TjabN28O9Xn//feNJFNfX2+M+SwsJiQkGJ/PF+qzbt0643K5TGdnZ//ugEW+8pWvmJ///OeMYQ+dPn3ajBkzxtTU1JhvfvOboYDCeF65hx9+2EyaNOmibYxj9MXUJZ5z586poaFBRUVFoXUJCQkqKipSfX19FCuLLceOHZPP5wsbR7fbrcLCwtA41tfXKz09XVOmTAn1KSoqUkJCgvbs2dPvNdvC7/dLkjIyMiRJDQ0NOn/+fNhYjhs3Tnl5eWFjOXHiRGVlZYX6FBcXKxAI6NChQ/1YvR26u7u1adMmnT17Vl6vlzHsofLycpWUlISNm8S/yUh9+OGHysnJ0XXXXacFCxaoublZEuNog5j6scC//OUv6u7uDvvHIElZWVn64IMPolRV7PH5fJJ00XG80Obz+ZSZmRnWnpSUpIyMjFCfgSYYDGr58uW65ZZbNGHCBEmfjVNKSorS09PD+n5+LC821hfaBooDBw7I6/Wqo6NDaWlp2rJli/Lz89XY2MgYRmjTpk36/e9/r717936hjX+TV66wsFAbNmzQ2LFj1dLSojVr1ujWW2/VwYMHGUcLxFRAAaKpvLxcBw8e1FtvvRXtUmLS2LFj1djYKL/fr9/85jcqKytTXV1dtMuKOcePH9cPfvAD1dTU6Jprrol2OTFtzpw5ob8LCgpUWFiokSNH6te//rUGDRoUxcogxdhdPMOGDVNiYuIXZlG3trbK4/FEqarYc2GsLjWOHo9HbW1tYe1dXV06derUgBzriooKbd26VW+++aZGjBgRWu/xeHTu3Dm1t7eH9f/8WF5srC+0DRQpKSkaPXq0Jk+erOrqak2aNElPP/00YxihhoYGtbW16W/+5m+UlJSkpKQk1dXV6ZlnnlFSUpKysrIYzx5KT0/X1772NR05coR/lxaIqYCSkpKiyZMnq7a2NrQuGAyqtrZWXq83ipXFllGjRsnj8YSNYyAQ0J49e0Lj6PV61d7eroaGhlCfHTt2KBgMqrCwsN9rjhZjjCoqKrRlyxbt2LFDo0aNCmufPHmykpOTw8ayqalJzc3NYWN54MCBsMBXU1Mjl8ul/Pz8/tkRCwWDQXV2djKGEZo5c6YOHDigxsbG0DJlyhQtWLAg9Dfj2TNnzpzR0aNHlZ2dzb9LG0R7lm6kNm3aZJxOp9mwYYM5fPiwWbp0qUlPTw+bRY3PZvjv27fP7Nu3z0gy//Zv/2b27dtn/vSnPxljPrvNOD093bzyyitm//795vbbb7/obcY33XST2bNnj3nrrbfMmDFjBtxtxsuWLTNut9vs3Lkz7FbETz75JNTnvvvuM3l5eWbHjh3mvffeM16v13i93lD7hVsRZ82aZRobG8327dvN8OHDB9StiA8++KCpq6szx44dM/v37zcPPvigcTgc5ne/+50xhjG8Wv/3Lh5jGM8r9cADD5idO3eaY8eOmbffftsUFRWZYcOGmba2NmMM4xhtMRdQjDHmZz/7mcnLyzMpKSnm5ptvNrt37452SdZ58803jaQvLGVlZcaYz241/slPfmKysrKM0+k0M2fONE1NTWHb+Pjjj80999xj0tLSjMvlMt///vfN6dOno7A30XOxMZRknn/++VCfTz/91PzjP/6j+cpXvmJSU1PNd7/7XdPS0hK2nT/+8Y9mzpw5ZtCgQWbYsGHmgQceMOfPn+/nvYmef/iHfzAjR440KSkpZvjw4WbmzJmhcGIMY3i1Ph9QGM8rc9ddd5ns7GyTkpJivvrVr5q77rrLHDlyJNTOOEaXwxhjonPuBgAA4OJiag4KAAAYGAgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALDO/wM/3qp+AYSvrAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "200.0"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(play=True)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第9章-策略梯度算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
